# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from test_collective_multi_nodes import (
    TestCollectiveAPIRunnerBase,
    runtime_main,
)

import paddle
from paddle import nn
from paddle.distributed import fleet


def weight_init(mp, shape, col=True, seed=1024):
    rng = np.random.RandomState(seed)
    w = rng.normal(0, 0.02, size=shape)
    if mp is None:
        _w = w
    else:
        if col:
            step = shape[1] // mp.nranks
            _w = w[:, mp.rank * step : mp.rank * step + step]
        else:
            step = shape[0] // mp.nranks
            _w = w[mp.rank * step : mp.rank * step + step, :]
    return paddle.nn.initializer.Assign(_w)


class Criterion(nn.Layer):
    def __init__(self):
        super().__init__()
        self.loss_func = nn.MSELoss(reduction="mean")

    def forward(self, pred, label):
        loss = self.loss_func(pred, label)
        return loss


class ModelPipeline(fleet.meta_parallel.PipelineLayer):
    def __init__(self, hcg):
        paddle.seed(1024)
        dp_linear = nn.Linear(32, 128)
        self.layers_pp = []
        self.topology = hcg.topology()
        self.layers_pp.append(dp_linear)
        mp = hcg.get_model_parallel_group()
        for i in range(6):
            if mp is not None and mp.nranks > 1:
                mp_linear_1 = fleet.meta_parallel.ColumnParallelLinear(
                    128,
                    512,
                    weight_attr=weight_init(mp, (128, 512), True, 1204 + i),
                    has_bias=True,
                    gather_output=False,
                )
                mp_linear_2 = fleet.meta_parallel.RowParallelLinear(
                    512,
                    128,
                    weight_attr=weight_init(mp, (512, 128), False, 2012 + i),
                    has_bias=True,
                    input_is_parallel=True,
                )
            else:
                mp_linear_1 = nn.Linear(
                    128,
                    512,
                    weight_attr=weight_init(None, (128, 512), True, 1204 + i),
                )
                mp_linear_2 = nn.Linear(
                    512,
                    128,
                    weight_attr=weight_init(None, (512, 128), True, 2012 + i),
                )
            act = nn.ReLU6()
            layer_seq = nn.Sequential(mp_linear_1, mp_linear_2, act)
            self.layers_pp.append(layer_seq)

        out = nn.Linear(128, 32)
        self.layers_pp.append(out)
        super().__init__(
            layers=self.layers_pp, loss_fn=Criterion(), topology=self.topology
        )


class Model(nn.Layer):
    def __init__(self, hcg):
        super().__init__()
        paddle.seed(1024)
        dp_linear = nn.Linear(32, 128)
        self.layers_pp = []
        self.layers_pp.append(dp_linear)
        mp = hcg.get_model_parallel_group() if hcg else None
        for i in range(6):
            if mp is not None and mp.nranks > 1:
                mp_linear_1 = fleet.meta_parallel.ColumnParallelLinear(
                    128,
                    512,
                    weight_attr=weight_init(mp, (128, 512), True, 1204 + i),
                    has_bias=True,
                    gather_output=False,
                )
                mp_linear_2 = fleet.meta_parallel.RowParallelLinear(
                    512,
                    128,
                    weight_attr=weight_init(mp, (512, 128), False, 2012 + i),
                    has_bias=True,
                    input_is_parallel=True,
                )
            else:
                mp_linear_1 = nn.Linear(
                    128,
                    512,
                    weight_attr=weight_init(None, (128, 512), True, 1204 + i),
                )
                mp_linear_2 = nn.Linear(
                    512,
                    128,
                    weight_attr=weight_init(None, (512, 128), True, 2012 + i),
                )
            act = nn.ReLU6()
            layer_seq = nn.Sequential(mp_linear_1, mp_linear_2, act)
            self.layers_pp.append(layer_seq)

        out = nn.Linear(128, 32)
        self.layers_pp.append(out)
        self.layers = nn.Sequential(*self.layers_pp)

    def forward(self, x):
        return self.layers(x)


class TestDygraphHybridFp16(TestCollectiveAPIRunnerBase):
    def __init__(self):
        pass

    def check_pass(self, *args, **kwargs):
        from common import init_parallel_env

        import paddle
        from paddle.distributed import fleet

        hcg = init_parallel_env("DP4-MP2-PP2-SH1-O1", 64)
        pp_degree = hcg.get_pipe_parallel_world_size()
        import numpy as np

        crit = Criterion()
        if pp_degree <= 1:
            model = Model(hcg)
        else:
            model = ModelPipeline(hcg)

        model_base = Model(None)

        optimizer = paddle.optimizer.Adam(
            learning_rate=0.01,
            parameters=model.parameters(),
            multi_precision=True,
        )
        optimizer_base = paddle.optimizer.Adam(
            learning_rate=0.01, parameters=model_base.parameters()
        )

        scaler = paddle.amp.GradScaler(init_loss_scaling=4096)
        scaler = fleet.distributed_scaler(scaler)
        model = paddle.amp.decorate(
            models=model, level='O2', save_dtype='float32'
        )

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

        loss_hybrid_arr = []
        loss_base_arr = []
        x = paddle.to_tensor(np.random.random((16, 32))).astype("float32")
        y = paddle.to_tensor(np.random.random((16, 32))).astype("float32")

        for _ in range(2):
            if pp_degree > 1:
                with paddle.amp.auto_cast(True, level='O2'):
                    loss = model.train_batch(
                        [x, y], optimizer=optimizer, scaler=scaler
                    )
            else:
                with paddle.amp.auto_cast(True, level='O2'):
                    output = model(x)
                    loss = crit(output, y)
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
                optimizer.clear_grad()

            # baseline loss
            with paddle.amp.auto_cast(True, level='O2'):
                output_base = model_base(x)
                loss_base = crit(output_base, y)
            loss_base.backward()
            optimizer_base.step()
            optimizer_base.clear_grad()

            loss_base_arr.append(loss_base.numpy())
            loss_hybrid_arr.append(loss)
        np.testing.assert_allclose(
            loss_base_arr, loss_hybrid_arr, rtol=1e-3, atol=1e-3
        )


if __name__ == "__main__":
    runtime_main(TestDygraphHybridFp16, "dpppmp")
